#include <core/class.h>
#include <core/sys.h>
#include <core/errno.h>
#include <core/time.h>
#include <core/kobject.h>
#include <core/waitqueue.h>
#include <task.h>
#include <uaccess.h>
#include <malloc.h>
#include <log.h>

#include <list.h>
#include <atomic.h>

/*
 * Class template for kobject_t: the reference count's counter starts at 0 and
 * no signal bits are set. (Presumably the class macros apply these values to
 * each new instance before the constructor runs — confirm in core/class.h.)
 */
class_impl(kobject_t){
    .refcount = (atomic_t){
        .counter = 0,
    },
    .signals = 0,
};

/*
 * kobject_t constructor: initialises the object's wait-queue head. Pollers in
 * sys_object_wait_many() add entries to this queue, and object_signal()
 * calls wake_up() on it.
 */
constructor(kobject_t)
{
    init_waitqueue_head(&this->queue);
}

/*
 * init_poll_wqueue - prepare per-call state for sys_object_wait_many().
 * @wqueue: state to initialise (poller thread, buffers, timeout flag).
 * @desc:   user-space array of @nr poll descriptors; copied into a kernel
 *          buffer with copy_from_user().
 * @nr:     number of descriptors.
 *
 * Allocates one zeroed wait entry per descriptor (so every wait_flag starts
 * at 0) and initialises each entry's wq_entry. If @nr <= 0 or either
 * allocation fails, both wqueue->desc and wqueue->entries are left NULL so the
 * caller can detect the failure.
 */
static void init_poll_wqueue(struct kobj_poll_wqueue_t *wqueue, struct kobj_poll_desc_t *desc, int nr)
{
    wqueue->poll_thread = thread_self();
    wqueue->timeout = FALSE;
    wqueue->desc = NULL;
    wqueue->entries = NULL;

    if (nr <= 0)
        return;

    /* Pass the element count to calloc separately instead of pre-multiplying,
     * so a calloc that checks count * size for overflow can reject it. */
    wqueue->desc = calloc(nr, sizeof(*wqueue->desc));
    wqueue->entries = calloc(nr, sizeof(*wqueue->entries));
    if (!wqueue->desc || !wqueue->entries)
    {
        free(wqueue->desc);
        free(wqueue->entries);
        wqueue->desc = NULL;
        wqueue->entries = NULL;
        return;
    }

    copy_from_user((char *)wqueue->desc, (char *)desc, sizeof(*wqueue->desc) * nr);

    for (int i = 0; i < nr; i++)
        init_wait_entry(&wqueue->entries[i].wq_entry);
}

/* Release the wait-entry array and descriptor buffer owned by @wqueue. */
static void free_poll_wqueue(struct kobj_poll_wqueue_t *wqueue)
{
    void *entry_buf = wqueue->entries;
    void *desc_buf = wqueue->desc;

    free(entry_buf);
    free(desc_buf);
}

static int kobj_poll_wake(struct wait_queue_entry *wq_entry)
{
    struct kobj_wait_queue_entry_t *kentry = container_of(wq_entry, struct kobj_wait_queue_entry_t, wq_entry);
    return autoremove_wake_function(wq_entry);
}

/*
 * Timer callback armed by sys_object_wait_many(): @data is the poll state.
 * Marks the wait as timed out and resumes the polling thread. Always returns 0.
 */
static int object_timer_function(struct timer_t *timer, void *data)
{
    struct kobj_poll_wqueue_t *state = data;

    state->timeout = TRUE;
    thread_resume(state->poll_thread);

    return 0;
}

/*
 * sys_object_wait_many - wait until any of @nr kernel objects is signalled.
 * @desc:    user-space array of @nr poll descriptors. Each names a handle slot
 *           and the signal mask to wait for. On return, each entry's rsignals
 *           is written back as (signals & obj->signals), or 0 when the slot
 *           does not resolve to a kobject_t.
 * @nr:      number of descriptors; values <= 0 return immediately.
 * @timeout: user-space pointer to a duration passed to ns_to_ktime(), or NULL
 *           for no timer. A duration of 0 polls the objects once.
 *
 * The wait ends when some object's poll() mask intersects the requested
 * signals, when no listed object can be polled (invalid slot or no poll()),
 * or when the timeout fires.
 */
void sys_object_wait_many(struct kobj_poll_desc_t *desc, int nr, kduration_t *timeout)
{
    struct kobj_poll_wqueue_t wqueue = {};
    timer_t *timer = NULL;

    if (nr <= 0)
        return;

    init_poll_wqueue(&wqueue, desc, nr);
    if (!wqueue.desc || !wqueue.entries)
    {
        free_poll_wqueue(&wqueue);
        return;
    }

    if (timeout != NULL)
    {
        /* timeout lives in user memory like desc: copy it, don't dereference it. */
        kduration_t duration = 0;
        copy_from_user((char *)&duration, (char *)timeout, sizeof(duration));

        if (duration == 0)
        {
            wqueue.timeout = TRUE;
        }
        else
        {
            timer = new (timer_t);
            if (timer)
            {
                timer_init(timer, object_timer_function, &wqueue);
                timer_start_now(timer, ns_to_ktime(duration));
            }
            else
            {
                /* No timer available: poll once rather than risk an unbounded wait. */
                wqueue.timeout = TRUE;
            }
        }
    }

    for (;;)
    {
        /* Both counters describe a single pass, so they restart at zero every
         * time the thread is resumed; otherwise null_poll_count would keep
         * growing and could reach nr spuriously. */
        int count = 0;
        int null_poll_count = 0;

        for (int i = 0; i < nr; i++)
        {
            kobject_t *obj = dynamic_cast(kobject_t)(slot_get(process_self(), wqueue.desc[i].slot));

            /* An unresolvable slot or an object without poll() can never become ready. */
            if (!obj || !obj->poll)
            {
                null_poll_count++;
                continue;
            }

            uint32_t mask = obj->poll(obj);

            if (mask & wqueue.desc[i].signals)
                count++;

            /* Queue on the object only while nothing is ready yet. wait_flag
             * keeps the same entry from being added twice, which would
             * corrupt the queue's list. */
            if (!count && wqueue.entries[i].wait_flag == 0)
            {
                wqueue.entries[i].wait_flag = 1;
                init_waitqueue_func_entry(&wqueue.entries[i].wq_entry, kobj_poll_wake);
                add_wait_queue(&obj->queue, &wqueue.entries[i].wq_entry);
            }
        }

        if (count || null_poll_count == nr || wqueue.timeout)
            break;

        thread_suspend(thread_self());
    }

    if (timer)
    {
        timer_cancel(timer);
        delete (timer);
    }

    for (int i = 0; i < nr; i++)
    {
        kobject_t *obj = dynamic_cast(kobject_t)(slot_get(process_self(), wqueue.desc[i].slot));

        if (!obj)
        {
            wqueue.desc[i].rsignals = 0;
            continue;
        }

        wqueue.desc[i].rsignals = wqueue.desc[i].signals & obj->signals;

        /* Only detach entries that this call actually queued. */
        if (wqueue.entries[i].wait_flag)
            remove_wait_queue(&obj->queue, &wqueue.entries[i].wq_entry);
    }

    /* rsignals were filled in the kernel copy; hand them back to the caller
     * before that copy is freed. */
    copy_to_user((char *)desc, (char *)wqueue.desc, sizeof(struct kobj_poll_desc_t) * nr);

    free_poll_wqueue(&wqueue);
}

/*
 * object_signal - update @obj's signal bits, then call wake_up() on its queue.
 * Bits in @clear_mask are removed first and bits in @set_mask are added
 * afterwards, so a bit named in both masks ends up set. wake_up() is called
 * even when no bit actually changed.
 */
void object_signal(kobject_t *obj, uint32_t set_mask, uint32_t clear_mask)
{
    obj->signals = (obj->signals & ~clear_mask) | set_mask;
    wake_up(&obj->queue);
}

/*
 * sys_object_signal - system-call entry for object_signal().
 * Resolves @slot in the calling process and applies @set_mask / @clear_mask to
 * that object. If the lookup does not produce a kobject_t (assumed to be
 * reported as NULL by slot_get()/dynamic_cast), the request is ignored instead
 * of dereferencing a NULL pointer.
 */
void sys_object_signal(int slot, uint32_t set_mask, uint32_t clear_mask)
{
    kobject_t *obj = dynamic_cast(kobject_t)(slot_get(process_self(), slot));

    if (!obj)
        return;

    object_signal(obj, set_mask, clear_mask);
}